femtocrux 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
femtocrux/VERSION CHANGED
@@ -1 +1 @@
1
- 0.3.0
1
+ 0.4.1
femtocrux/__init__.py CHANGED
@@ -1,6 +1,12 @@
1
- from .client.client import CompilerClient, TFLiteModel, FQIRModel
1
+ from .client.client import CompilerClient, TFLiteModel, FQIRModel, ManagedCompilerClient
2
2
  from .version import __version__
3
3
 
4
4
  # PEP 8 definiton of public API
5
5
  # https://peps.python.org/pep-0008/#public-and-internal-interfaces
6
- __all__ = ["CompilerClient", "TFLiteModel", "FQIRModel", "__version__"]
6
+ __all__ = [
7
+ "CompilerClient",
8
+ "TFLiteModel",
9
+ "FQIRModel",
10
+ "__version__",
11
+ "ManagedCompilerClient",
12
+ ]
@@ -12,6 +12,7 @@ import queue
12
12
  import sys
13
13
  import time
14
14
  from typing import Any, List, Tuple, Union
15
+ from contextlib import contextmanager
15
16
 
16
17
  from fmot.fqir import GraphProto
17
18
 
@@ -25,7 +26,7 @@ import femtocrux.grpc.compiler_service_pb2_grpc as cs_pb2_grpc
25
26
  __docker_registry__ = "ghcr.io"
26
27
 
27
28
 
28
- def __get_docker_image_name__() -> str:
29
+ def _get_docker_image_name() -> str:
29
30
  """
30
31
  Returns the docker image name. For testing, override with the
31
32
  FEMTOCRUX_SERVER_IMAGE_NAME environment variable.
@@ -46,11 +47,11 @@ def __get_docker_image_name__() -> str:
46
47
  return remote_image_name
47
48
 
48
49
 
49
- __docker_image_name__ = __get_docker_image_name__()
50
+ __docker_image_name__ = _get_docker_image_name()
50
51
 
51
52
 
52
53
  # Set up logging
53
- def __init_logger__():
54
+ def _init_logger():
54
55
  """Init a basic logger to stderr."""
55
56
  logger = logging.getLogger(__name__)
56
57
  logger.setLevel(logging.INFO)
@@ -61,10 +62,10 @@ def __init_logger__():
61
62
  return logger
62
63
 
63
64
 
64
- logger = __init_logger__()
65
+ logger = _init_logger()
65
66
 
66
67
 
67
- def __env_var_to_bool__(varname: str, default: bool = False) -> bool:
68
+ def _env_var_to_bool(varname: str, default: bool = False) -> bool:
68
69
  """Parse an environment varaible as a boolean."""
69
70
  try:
70
71
  value = os.environ[varname]
@@ -90,24 +91,22 @@ class Model:
90
91
  :class:`~femtocrux.client.client.TFLiteModel`.
91
92
  """
92
93
 
93
- def __get_message__(self, options: dict = {}) -> cs_pb2.model:
94
+ def _get_message(self, options: dict = {}) -> cs_pb2.model:
94
95
  # Format the options
95
96
  options_struct = google.protobuf.struct_pb2.Struct()
96
97
  options_struct.update(options)
97
98
 
98
99
  # Construct the model with IR
99
- return cs_pb2.model(
100
- **{self.__ir_name__: self.__get_ir__()}, options=options_struct
101
- )
100
+ return cs_pb2.model(**{self._ir_name: self._get_ir()}, options=options_struct)
102
101
 
103
102
  @property
104
- def __ir_name__(self) -> str:
103
+ def _ir_name(self) -> str:
105
104
  """
106
105
  Subclass overrides this to tell which 'ir' field is being returned.
107
106
  """
108
107
  return NotImplementedError("Subclass must override with IR type.")
109
108
 
110
- def __get_ir__(self) -> Tuple[str, Any]:
109
+ def _get_ir(self) -> Tuple[str, Any]:
111
110
  """
112
111
  Subclass overrides this to implement the 'ir' field of the model's
113
112
  grpc message.
@@ -135,10 +134,10 @@ class FQIRModel(Model):
135
134
  sequence_dim: int = None
136
135
 
137
136
  @property
138
- def __ir_name__(self) -> str:
137
+ def _ir_name(self) -> str:
139
138
  return "fqir"
140
139
 
141
- def __get_ir__(self) -> Any:
140
+ def _get_ir(self) -> Any:
142
141
  # Serialize FQIR via pickle
143
142
  model = pickle.dumps(self.graph_proto)
144
143
 
@@ -167,10 +166,10 @@ class TFLiteModel(Model):
167
166
  signature_name: str = None
168
167
 
169
168
  @property
170
- def __ir_name__(self) -> str:
169
+ def _ir_name(self) -> str:
171
170
  return "tflite"
172
171
 
173
- def __get_ir__(self) -> Any:
172
+ def _get_ir(self) -> Any:
174
173
  return cs_pb2.tflite(model=self.flatbuffer, signature_name=self.signature_name)
175
174
 
176
175
 
@@ -185,31 +184,31 @@ class Simulator:
185
184
 
186
185
  # Create an event stream fed by a queue
187
186
  self.request_queue = queue.SimpleQueue()
188
- request_iterator = iter(self.request_queue.get, self.__request_sentinel__)
189
- self.response_iterator = client.__simulate__(request_iterator)
187
+ request_iterator = iter(self.request_queue.get, self._request_sentinel)
188
+ self.response_iterator = client._simulate(request_iterator)
190
189
 
191
190
  # Compile the model with the first message
192
- model_msg = model.__get_message__(options)
191
+ model_msg = model._get_message(options)
193
192
  simulation_start_msg = cs_pb2.simulation_input(model=model_msg)
194
- self.__send_request__(simulation_start_msg)
193
+ self._send_request(simulation_start_msg)
195
194
 
196
195
  # Check compilation status
197
- self.__get_response__()
196
+ self._get_response()
198
197
 
199
198
  def __del__(self):
200
199
  """Close any open streams."""
201
- self.__send_request__(self.__request_sentinel__)
200
+ self._send_request(self._request_sentinel)
202
201
 
203
- def __send_request__(self, msg):
202
+ def _send_request(self, msg):
204
203
  self.request_queue.put(msg)
205
204
 
206
- def __get_response__(self):
205
+ def _get_response(self):
207
206
  response = next(self.response_iterator)
208
- self.client.__check_status__(response.status)
207
+ self.client._check_status(response.status)
209
208
  return response
210
209
 
211
210
  @property
212
- def __request_sentinel__(self) -> Any:
211
+ def _request_sentinel(self) -> Any:
213
212
  """Sentinel value to close the request queue."""
214
213
  return None
215
214
 
@@ -218,7 +217,7 @@ class Simulator:
218
217
  inputs: Union[np.array, Iterable[np.array]],
219
218
  quantize_inputs: bool = False,
220
219
  dequantize_outputs: bool = False,
221
- sim_duration: float = None,
220
+ input_period: float = None,
222
221
  ) -> List[np.array]:
223
222
  """
224
223
  Simulates the model on the given inputs.
@@ -243,7 +242,8 @@ class Simulator:
243
242
  static power consumption. For example, in a streaming model, this should be the
244
243
  time elapsed between model invocations. By default, this is just the estimated
245
244
  latency of the model.
246
- :type sim_duration: float, optional
245
+ :type input_period (float, optional): Duration between each input in a sequence,
246
+ in seconds.
247
247
 
248
248
  :rtype: list
249
249
  :return: The output tensors.
@@ -260,11 +260,11 @@ class Simulator:
260
260
  data=[numpy_to_ndarray(x) for x in inputs],
261
261
  quantize_inputs=quantize_inputs,
262
262
  dequantize_outputs=dequantize_outputs,
263
- sim_duration=sim_duration,
263
+ input_period=input_period,
264
264
  )
265
265
  )
266
- self.__send_request__(simulation_request)
267
- response = self.__get_response__()
266
+ self._send_request(simulation_request)
267
+ response = self._get_response()
268
268
 
269
269
  return [ndarray_to_numpy(x) for x in response.data], response.report
270
270
 
@@ -279,21 +279,21 @@ class CompilerClientImpl:
279
279
  def __init__(self, channel, stub):
280
280
  self.channel = channel
281
281
  self.stub = stub
282
- self.__check_version__()
282
+ self._check_version()
283
283
 
284
- def __check_status__(self, status):
284
+ def _check_status(self, status):
285
285
  """Check a status response, raising an exception if unsuccessful."""
286
286
  if not status.success:
287
287
  raise RuntimeError(
288
288
  "Client received error from compiler server:\n%s" % status.msg
289
289
  )
290
290
 
291
- def __check_version__(self):
291
+ def _check_version(self):
292
292
  """Verify the server's version matches the client."""
293
293
 
294
294
  from femtocrux.version import __version__ as client_version
295
295
 
296
- server_version = self.__server_version__()
296
+ server_version = self._server_version()
297
297
  assert (
298
298
  client_version == server_version
299
299
  ), """
@@ -319,17 +319,17 @@ class CompilerClientImpl:
319
319
  :return: A zip archive of compiler artifacts.
320
320
  """
321
321
 
322
- response = self.stub.compile(model.__get_message__(options))
323
- self.__check_status__(response.status)
322
+ response = self.stub.compile(model._get_message(options))
323
+ self._check_status(response.status)
324
324
  return response.bitfile
325
325
 
326
- def __ping__(self, message: bytes) -> None:
326
+ def _ping(self, message: bytes) -> None:
327
327
  """Pings the server with a message."""
328
328
  response = self.stub.ping(cs_pb2.data(data=message))
329
329
  if response.data != message:
330
330
  raise RuntimeError("Server response does not match request data!")
331
331
 
332
- def __simulate__(self, in_stream: Iterable) -> Iterable:
332
+ def _simulate(self, in_stream: Iterable) -> Iterable:
333
333
  """Calls the 'simulator' bidirectional streaming RPC."""
334
334
  return self.stub.simulate(in_stream)
335
335
 
@@ -348,7 +348,7 @@ class CompilerClientImpl:
348
348
  """
349
349
  return Simulator(client=self, model=model, options=options)
350
350
 
351
- def __server_version__(self) -> str:
351
+ def _server_version(self) -> str:
352
352
  """Queries the femtocrux version running on the server."""
353
353
  response = self.stub.version(google.protobuf.empty_pb2.Empty())
354
354
  return response.version
@@ -373,24 +373,49 @@ class CompilerClient(CompilerClientImpl):
373
373
  self.container = None # For __del__
374
374
 
375
375
  # Start a new docker server
376
- self.container = self.__create_docker_server__(docker_kwargs)
377
- self.__wait_for_server_ready__()
378
- self.__init_network_info__(self.container)
376
+ self.container = self._create_docker_server(docker_kwargs)
377
+ self._wait_for_server_ready()
378
+ self._init_network_info(self.container)
379
379
 
380
380
  # Establish a connection to the server
381
- self.channel = self.__connect__()
381
+ self.channel = self._connect()
382
382
 
383
383
  # Initialize the client on this channel
384
384
  self.stub = cs_pb2_grpc.CompileStub(self.channel)
385
385
  super().__init__(self.channel, self.stub)
386
386
 
387
+ @property
388
+ def status(self):
389
+ if self.container is not None:
390
+ return self.container.status
391
+ else:
392
+ return "exited"
393
+
394
+ @property
395
+ def name(self):
396
+ if self.container is not None:
397
+ return self.container.name
398
+ else:
399
+ return None
400
+
401
+ def close(self):
402
+ if self.container is not None:
403
+ try:
404
+ self.container.stop()
405
+ except Exception as e:
406
+ logger.info(f"Image already closed... skipping close\n{e}")
407
+
387
408
  def __del__(self):
388
409
  """Reclaim system resources."""
389
410
  if self.container is not None:
390
- self.container.kill()
391
- self.container = None
411
+ cli = docker.DockerClient()
412
+ try:
413
+ container = cli.containers.get(self.name)
414
+ container.stop()
415
+ except docker.errors.NotFound:
416
+ pass
392
417
 
393
- def __get_docker_api_client__(self):
418
+ def _get_docker_api_client(self):
394
419
  """Get a client to the Docker daemon."""
395
420
  try:
396
421
  return docker.from_env()
@@ -400,7 +425,7 @@ class CompilerClient(CompilerClientImpl):
400
425
  Please ensure it is installed and running."""
401
426
  ) from exc
402
427
 
403
- def __init_network_info__(self, container):
428
+ def _init_network_info(self, container):
404
429
  """
405
430
  For local connections only.
406
431
 
@@ -424,7 +449,7 @@ class CompilerClient(CompilerClientImpl):
424
449
  socket = bound_sockets[0] # In case of multiple, take the first one
425
450
  self.__channel_port__ = socket["HostPort"]
426
451
 
427
- def __connect__(self) -> Any:
452
+ def _connect(self) -> Any:
428
453
  """Establishes a gRPC connection to the server."""
429
454
 
430
455
  # Open a gRPC channel to the server
@@ -467,20 +492,20 @@ class CompilerClient(CompilerClientImpl):
467
492
  return self.__channel_port__
468
493
 
469
494
  @property
470
- def __container_port__(self) -> int:
495
+ def _container_port(self) -> int:
471
496
  """Port used inside the container."""
472
497
  return 50051
473
498
 
474
499
  @property
475
- def __container_label__(self) -> str:
500
+ def _container_label(self) -> str:
476
501
  """Label attached to identify containers started by this client."""
477
502
  return "femtocrux_server"
478
503
 
479
- def __get_unused_container_name__(self) -> str:
504
+ def _get_unused_container_name(self) -> str:
480
505
  """Get an unused container name."""
481
506
 
482
507
  # Search for an unused name
483
- client = self.__get_docker_api_client__()
508
+ client = self._get_docker_api_client()
484
509
  container_idx = 0
485
510
  while True:
486
511
  name = "femtocrux_server_%d" % container_idx
@@ -492,7 +517,7 @@ class CompilerClient(CompilerClientImpl):
492
517
 
493
518
  container_idx += 1
494
519
 
495
- def __pull_docker_image__(self):
520
+ def _pull_docker_image(self):
496
521
  """Pull the Docker image from remote."""
497
522
 
498
523
  logger.info(
@@ -506,7 +531,7 @@ class CompilerClient(CompilerClientImpl):
506
531
  )
507
532
 
508
533
  # Log in to Github
509
- client = self.__get_docker_api_client__()
534
+ client = self._get_docker_api_client()
510
535
  while True:
511
536
  # Get the password
512
537
  manual_pass = True
@@ -561,7 +586,7 @@ class CompilerClient(CompilerClientImpl):
561
586
 
562
587
  logger.info("Download completed.")
563
588
 
564
- def __create_docker_server__(
589
+ def _create_docker_server(
565
590
  self, docker_kwargs: dict[str, Any] = None
566
591
  ) -> docker.models.containers.Container:
567
592
  """
@@ -571,7 +596,7 @@ class CompilerClient(CompilerClientImpl):
571
596
  docker_kwargs = {}
572
597
 
573
598
  # Get a client for the Docker daemon
574
- client = self.__get_docker_api_client__()
599
+ client = self._get_docker_api_client()
575
600
 
576
601
  # Pull the image, if not available
577
602
  existing_image_names = [
@@ -583,7 +608,7 @@ class CompilerClient(CompilerClientImpl):
583
608
  image_not_found_msg = (
584
609
  "Failed to find the docker image %s locally." % __docker_image_name__
585
610
  )
586
- if not __env_var_to_bool__("FS_ALLOW_DOCKER_PULL", default=True):
611
+ if not _env_var_to_bool("FS_ALLOW_DOCKER_PULL", default=True):
587
612
  raise RuntimeError(
588
613
  """
589
614
  %s
@@ -594,29 +619,29 @@ class CompilerClient(CompilerClientImpl):
594
619
 
595
620
  # Pull the image from remote
596
621
  logger.info(image_not_found_msg)
597
- self.__pull_docker_image__()
622
+ self._pull_docker_image()
598
623
 
599
624
  # Bind a random port on the host to the container's gRPC port
600
- port_interface = {self.__container_port__: None}
601
- command = "--port %s" % self.__container_port__
625
+ port_interface = {self._container_port: None}
626
+ command = "--port %s" % self._container_port
602
627
 
603
628
  # Start a container running the server
604
629
  container = client.containers.run(
605
630
  __docker_image_name__,
606
631
  command=command, # Appends entrypoint with args
607
632
  detach=True,
608
- labels=[self.__container_label__],
633
+ labels=[self._container_label],
609
634
  stderr=True,
610
635
  stdout=True,
611
636
  ports=port_interface,
612
- name=self.__get_unused_container_name__(),
637
+ name=self._get_unused_container_name(),
613
638
  auto_remove=True,
614
639
  **docker_kwargs,
615
640
  )
616
641
 
617
642
  return container
618
643
 
619
- def __wait_for_server_ready__(self):
644
+ def _wait_for_server_ready(self):
620
645
  """
621
646
  Block until the Docker container is ready to handle requests.
622
647
  """
@@ -670,6 +695,15 @@ class CompilerClient(CompilerClientImpl):
670
695
  )
671
696
 
672
697
 
698
+ @contextmanager
699
+ def ManagedCompilerClient(docker_kwargs: dict[str, Any] = None) -> CompilerClient:
700
+ client = CompilerClient(docker_kwargs=docker_kwargs)
701
+ try:
702
+ yield client
703
+ finally:
704
+ client.close()
705
+
706
+
673
707
  if __name__ == "__main__":
674
708
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
675
709
  client = CompilerClient()
@@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
15
15
  from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
16
16
 
17
17
 
18
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x63ompiler_service.proto\x12\nfscompiler\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\"g\n\x04\x66qir\x12\r\n\x05model\x18\x01 \x01(\x0c\x12\x16\n\tbatch_dim\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x19\n\x0csequence_dim\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x0c\n\n_batch_dimB\x0f\n\r_sequence_dim\"G\n\x06tflite\x12\r\n\x05model\x18\x01 \x01(\x0c\x12\x1b\n\x0esignature_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x11\n\x0f_signature_name\"\x90\x01\n\x05model\x12 \n\x04\x66qir\x18\x01 \x01(\x0b\x32\x10.fscompiler.fqirH\x00\x12$\n\x06tflite\x18\x02 \x01(\x0b\x32\x12.fscompiler.tfliteH\x00\x12-\n\x07options\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructH\x01\x88\x01\x01\x42\x04\n\x02irB\n\n\x08_options\"&\n\x06status\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0b\n\x03msg\x18\x02 \x01(\t\"I\n\x12\x63ompiled_artifacts\x12\x0f\n\x07\x62itfile\x18\x01 \x01(\x0c\x12\"\n\x06status\x18\x02 \x01(\x0b\x32\x12.fscompiler.status\"&\n\x07ndarray\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x12\r\n\x05shape\x18\x02 \x03(\x03\"\x14\n\x04\x64\x61ta\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x95\x01\n\x0fsimulation_data\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.fscompiler.ndarray\x12\x17\n\x0fquantize_inputs\x18\x02 \x01(\x08\x12\x1a\n\x12\x64\x65quantize_outputs\x18\x03 \x01(\x08\x12\x19\n\x0csim_duration\x18\x04 \x01(\x02H\x00\x88\x01\x01\x42\x0f\n\r_sim_duration\"t\n\x10simulation_input\x12\"\n\x05model\x18\x01 \x01(\x0b\x32\x11.fscompiler.modelH\x00\x12+\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1b.fscompiler.simulation_dataH\x00\x42\x0f\n\rmodel_or_data\"j\n\x11simulation_output\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.fscompiler.ndarray\x12\x0e\n\x06report\x18\x02 \x01(\t\x12\"\n\x06status\x18\x03 \x01(\x0b\x32\x12.fscompiler.status\"\x1f\n\x0cversion_info\x12\x0f\n\x07version\x18\x01 \x01(\t2\x85\x02\n\x07\x43ompile\x12>\n\x07\x63ompile\x12\x11.fscompiler.model\x1a\x1e.fscompiler.compiled_artifacts\"\x00\x12,\n\x04ping\x12\x10.fscompiler.data\x1a\x10.fscompiler.data\"\x00\x12M\n\x08simulate\x12\x1c.fscompiler.simulation_input\x1a\x1d.fscompiler.simulation_output\"\x00(\x01\x30\x01\x12=\n\x07version\x12\x16.google.protobuf.Empty\x1a\x18.fscompiler.version_info\"\x00\x62\x06proto3')
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x63ompiler_service.proto\x12\nfscompiler\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\"g\n\x04\x66qir\x12\r\n\x05model\x18\x01 \x01(\x0c\x12\x16\n\tbatch_dim\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x19\n\x0csequence_dim\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x0c\n\n_batch_dimB\x0f\n\r_sequence_dim\"G\n\x06tflite\x12\r\n\x05model\x18\x01 \x01(\x0c\x12\x1b\n\x0esignature_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x11\n\x0f_signature_name\"\x90\x01\n\x05model\x12 \n\x04\x66qir\x18\x01 \x01(\x0b\x32\x10.fscompiler.fqirH\x00\x12$\n\x06tflite\x18\x02 \x01(\x0b\x32\x12.fscompiler.tfliteH\x00\x12-\n\x07options\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructH\x01\x88\x01\x01\x42\x04\n\x02irB\n\n\x08_options\"&\n\x06status\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0b\n\x03msg\x18\x02 \x01(\t\"I\n\x12\x63ompiled_artifacts\x12\x0f\n\x07\x62itfile\x18\x01 \x01(\x0c\x12\"\n\x06status\x18\x02 \x01(\x0b\x32\x12.fscompiler.status\"&\n\x07ndarray\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x12\r\n\x05shape\x18\x02 \x03(\x03\"\x14\n\x04\x64\x61ta\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\xc1\x01\n\x0fsimulation_data\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.fscompiler.ndarray\x12\x17\n\x0fquantize_inputs\x18\x02 \x01(\x08\x12\x1a\n\x12\x64\x65quantize_outputs\x18\x03 \x01(\x08\x12\x19\n\x0csim_duration\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cinput_period\x18\x05 \x01(\x02H\x01\x88\x01\x01\x42\x0f\n\r_sim_durationB\x0f\n\r_input_period\"t\n\x10simulation_input\x12\"\n\x05model\x18\x01 \x01(\x0b\x32\x11.fscompiler.modelH\x00\x12+\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x1b.fscompiler.simulation_dataH\x00\x42\x0f\n\rmodel_or_data\"j\n\x11simulation_output\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.fscompiler.ndarray\x12\x0e\n\x06report\x18\x02 \x01(\t\x12\"\n\x06status\x18\x03 \x01(\x0b\x32\x12.fscompiler.status\"\x1f\n\x0cversion_info\x12\x0f\n\x07version\x18\x01 \x01(\t2\x85\x02\n\x07\x43ompile\x12>\n\x07\x63ompile\x12\x11.fscompiler.model\x1a\x1e.fscompiler.compiled_artifacts\"\x00\x12,\n\x04ping\x12\x10.fscompiler.data\x1a\x10.fscompiler.data\"\x00\x12M\n\x08simulate\x12\x1c.fscompiler.simulation_input\x1a\x1d.fscompiler.simulation_output\"\x00(\x01\x30\x01\x12=\n\x07version\x12\x16.google.protobuf.Empty\x1a\x18.fscompiler.version_info\"\x00\x62\x06proto3')
19
19
 
20
20
  _globals = globals()
21
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -37,13 +37,13 @@ if _descriptor._USE_C_DESCRIPTORS == False:
37
37
  _globals['_DATA']._serialized_start=577
38
38
  _globals['_DATA']._serialized_end=597
39
39
  _globals['_SIMULATION_DATA']._serialized_start=600
40
- _globals['_SIMULATION_DATA']._serialized_end=749
41
- _globals['_SIMULATION_INPUT']._serialized_start=751
42
- _globals['_SIMULATION_INPUT']._serialized_end=867
43
- _globals['_SIMULATION_OUTPUT']._serialized_start=869
44
- _globals['_SIMULATION_OUTPUT']._serialized_end=975
45
- _globals['_VERSION_INFO']._serialized_start=977
46
- _globals['_VERSION_INFO']._serialized_end=1008
47
- _globals['_COMPILE']._serialized_start=1011
48
- _globals['_COMPILE']._serialized_end=1272
40
+ _globals['_SIMULATION_DATA']._serialized_end=793
41
+ _globals['_SIMULATION_INPUT']._serialized_start=795
42
+ _globals['_SIMULATION_INPUT']._serialized_end=911
43
+ _globals['_SIMULATION_OUTPUT']._serialized_start=913
44
+ _globals['_SIMULATION_OUTPUT']._serialized_end=1019
45
+ _globals['_VERSION_INFO']._serialized_start=1021
46
+ _globals['_VERSION_INFO']._serialized_end=1052
47
+ _globals['_COMPILE']._serialized_start=1055
48
+ _globals['_COMPILE']._serialized_end=1316
49
49
  # @@protoc_insertion_point(module_scope)
@@ -1,7 +1,7 @@
1
1
  import argparse
2
2
  from collections.abc import Iterable
3
3
  import concurrent
4
- import google.protobuf
4
+ import google.protobuf.json_format
5
5
  import grpc
6
6
  import logging
7
7
  import pickle
@@ -36,7 +36,7 @@ class CompileServicer(cs_pb2_grpc.CompileServicer):
36
36
  self.logger.setLevel(logging.DEBUG)
37
37
  self.logger.info("Starting compile server.")
38
38
 
39
- def __get_fqir_compiler__(self, model: cs_pb2.model) -> CompilerFrontend:
39
+ def _get_fqir_compiler(self, model: cs_pb2.model) -> CompilerFrontend:
40
40
  """Get a Torch compiler from an FQIR a model message"""
41
41
  # Deserialize FQIR
42
42
  fqir = model.fqir
@@ -49,19 +49,19 @@ class CompileServicer(cs_pb2_grpc.CompileServicer):
49
49
  seq_dim=field_or_none(fqir, "sequence_dim"),
50
50
  )
51
51
 
52
- def __get_tflite_compiler__(self, model: cs_pb2.model) -> CompilerFrontend:
52
+ def _get_tflite_compiler(self, model: cs_pb2.model) -> CompilerFrontend:
53
53
  """Get a TFLite compiler from a model message"""
54
54
  tflite = model.tflite
55
55
  return TFLiteCompiler(
56
56
  tflite.model, signature=field_or_none(tflite, "signature_name")
57
57
  )
58
58
 
59
- def __compile_model__(self, model: cs_pb2.model, context) -> CompilerFrontend:
59
+ def _compile_model(self, model: cs_pb2.model, context) -> CompilerFrontend:
60
60
  """Compile a model, for simulation or bitfile generation."""
61
61
  # Get a compiler for the model
62
62
  model_type_map = {
63
- "fqir": self.__get_fqir_compiler__,
64
- "tflite": self.__get_tflite_compiler__,
63
+ "fqir": self._get_fqir_compiler,
64
+ "tflite": self._get_tflite_compiler,
65
65
  }
66
66
  model_type = model.WhichOneof("ir")
67
67
  compiler = model_type_map[model_type](model)
@@ -85,7 +85,7 @@ class CompileServicer(cs_pb2_grpc.CompileServicer):
85
85
 
86
86
  # Compile the model
87
87
  try:
88
- compiler = self.__compile_model__(model, context)
88
+ compiler = self._compile_model(model, context)
89
89
  bitfile = compiler.dump_bitfile(encrypt=True)
90
90
  except Exception as exc:
91
91
  msg = "Compiler raised exception:\n%s" % (format_exception_from_exc(exc))
@@ -120,7 +120,7 @@ class CompileServicer(cs_pb2_grpc.CompileServicer):
120
120
 
121
121
  # Attempt compilation
122
122
  try:
123
- compiler = self.__compile_model__(model_request.model, context)
123
+ compiler = self._compile_model(model_request.model, context)
124
124
  except Exception as exc:
125
125
  msg = "Compiler raised exception:\n%s" % (
126
126
  format_exception_from_exc(exc)
@@ -158,7 +158,7 @@ class CompileServicer(cs_pb2_grpc.CompileServicer):
158
158
  data,
159
159
  quantize_inputs=sim_data.quantize_inputs,
160
160
  dequantize_outputs=sim_data.dequantize_outputs,
161
- input_period=field_or_none(sim_data, "sim_duration"),
161
+ input_period=field_or_none(sim_data, "input_period"),
162
162
  )
163
163
  except Exception as exc:
164
164
  msg = "Simulator raised exception:\n%s" % (
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: femtocrux
3
- Version: 0.3.0
3
+ Version: 0.4.1
4
4
  Summary: Femtosense Compiler
5
5
  Home-page: https://github.com/femtosense/femtocrux
6
6
  Author: Femtosense
@@ -37,7 +37,6 @@ The package itself is a thin Python wrapper communicating to the actual compiler
37
37
 
38
38
  Supported ML model representations:
39
39
  - Femtosense Quantized IR (FQIR)
40
- - Tensorflow Lite (TFLite)
41
40
 
42
41
  ## Installation
43
42
 
@@ -52,16 +51,51 @@ pip install femtocrux
52
51
 
53
52
  ## Basic Usage
54
53
 
54
+ ### `ManagedCompilerClient`
55
+ `ManagedCompilerClient` is the recommended way to use femtocrux within a context manager.
56
+
57
+ ```python
58
+ from femtocrux import ManagedCompilerClient, FQIRModel
59
+
60
+ # fqir_graph = ... # Assuming we have an FQIR graph using fmot
61
+ # inputs = ... # Assuming we have numpy array inputs to the model
62
+
63
+ fqir_model = FQIRModel(fqir_graph, batch_dim=0, sequence_dim=1)
64
+ with ManagedCompilerClient() as client:
65
+
66
+ # simulate execution, view power, energy, and latency metrics
67
+ simulator = client.simulate(fqir_model)
68
+ outputs, metrics = simulator.simulate(inputs)
69
+
70
+ # compile the model to a bitfile
71
+ bitstream = client.compile(fqir_model)
72
+ with open('my_bitfile.zip', 'wb') as f:
73
+ f.write(bitstream)
74
+ ```
75
+
76
+ The bitfile can be used to generate program files to run on the SPU, using [femtodriverpub](https://github.com/femtosense/femtodriverpub.git)
77
+
78
+ ### `CompilerClient`
79
+ `CompilerClient` is another interface to the compiler. This will be deprecated in future releases, so we recommend using `ManagedCompilerClient`.
80
+
55
81
  ```python
56
- # fqir_model = ... # Assuming we have an FQIRModel
57
- # inputs = ... # Assuming we have inputs for the model
82
+ from femtocrux import CompilerClient, FQIRModel
58
83
 
59
- from femtocrux import CompilerClient
84
+ # fqir_graph = ... # Assuming we have an FQIR graph using fmot
85
+ # inputs = ... # Assuming we have numpy array inputs to the model
86
+
87
+ fqir_model = FQIRModel(fqir_graph, batch_dim=0, sequence_dim=1)
60
88
  client = CompilerClient()
89
+
90
+ # simulate execution, view power, energy, and latency metrics
61
91
  simulator = client.simulate(fqir_model)
62
- outputs = simulator.simulate(inputs)
63
- ```
92
+ outputs, metrics = simulator.simulate(inputs)
64
93
 
65
- ## APIs
94
+ # compile the model to a bitfile
95
+ bitstream = client.compile(fqir_model)
96
+ with open('my_bitfile.zip', 'wb') as f:
97
+ f.write(bitstream)
66
98
 
67
- Coming soon...
99
+ # close the client (important to avoid background docker containers that continue running)
100
+ client.close()
101
+ ```
@@ -1,21 +1,21 @@
1
1
  femtocrux/ENV_REQUIREMENTS.sh,sha256=t_O1B4hJAMgxvH9gwp1qls6eVFmhSYBJe64KmuK_H-4,1389
2
2
  femtocrux/PY_REQUIREMENTS,sha256=HZOIodZiEjQF6kbKT1PxBU96fU8W7fypimUfIeJaoAQ,303
3
- femtocrux/VERSION,sha256=2RXMldbKj0euKXcT7UbU5cXZnd0p_Dxh4mO98wXytbA,6
4
- femtocrux/__init__.py,sha256=fQdrj_d4a2T4xIcKoTdTLokaZsDP5uAbW7c39YPzkX8,271
3
+ femtocrux/VERSION,sha256=9iGEzuh4fy9pQcggQaVyXU7cmqKT6-Xb9mRAboLsH-E,6
4
+ femtocrux/__init__.py,sha256=yIWd9I2PEXCn_PKIILAN3mkWeTf0tgtVualeTIHNxfQ,342
5
5
  femtocrux/version.py,sha256=uNg2kHxQo6oUN1ah7s9_85rCZVRoTHGPD1GAQPZW4lw,164
6
6
  femtocrux/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- femtocrux/client/client.py,sha256=VP6KsESnzUAI49tAsEFGzmWEK0_cnv0e4Mm2izOkOOU,21910
7
+ femtocrux/client/client.py,sha256=PNaXL_XTe_7_r14WWKmp1RnvAW-kvYzDXcimseA8sEo,22726
8
8
  femtocrux/grpc/__init__.py,sha256=uiMHQt5I2eAKJqI3Zh0h1Gm7cmPR4PbaGS71nCJQCGw,169
9
- femtocrux/grpc/compiler_service_pb2.py,sha256=XaawNGpFBdURP5JJScnzfA8JY-_J18vpVJ-Blzl5WF0,4339
9
+ femtocrux/grpc/compiler_service_pb2.py,sha256=RUBcT-S2sKr5kveHQQ7EMyAe5rGU_8JZVbhKAG693Fg,4424
10
10
  femtocrux/grpc/compiler_service_pb2_grpc.py,sha256=L9EQFYMFVACngUrAEmSJDgUTSPSi3rDoG7TERg-pFq4,7002
11
11
  femtocrux/server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  femtocrux/server/exceptions.py,sha256=lI6n471n5QKf5G3aL_1kuBVEItD-jBgithVVpPDwNYc,609
13
13
  femtocrux/server/healthcheck.py,sha256=ehqAwnv0D0zpy-AUZAPwv8rp874DZCwUmP8nzdXzZvI,1565
14
- femtocrux/server/server.py,sha256=KAH6paHEQrdw9PIv7GxSc_O2BBAKQHys3s8Ez5NkY1I,7928
14
+ femtocrux/server/server.py,sha256=tmXVleZQB59oFzdmut3na4NnDvr0gmxphXF3N3MQx6I,7919
15
15
  femtocrux/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  femtocrux/util/utils.py,sha256=FZ8cssDom4B3FDbVU_ew4Cf3wOWjo2w1jwcbnLzoYnM,1003
17
- femtocrux-0.3.0.dist-info/LICENSE,sha256=eN9ZI1xHjUmFvN3TEeop5kBGXRUBfbsl55KBNBYYFqI,36
18
- femtocrux-0.3.0.dist-info/METADATA,sha256=XiRTQKnAplRmDKUGlWjh9XlhbqiyyKFTuXAoZI2Auzg,2109
19
- femtocrux-0.3.0.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
20
- femtocrux-0.3.0.dist-info/top_level.txt,sha256=BkTttlioC3je__8577wxRieZqY3Abu7FOOdMnmYbcNI,10
21
- femtocrux-0.3.0.dist-info/RECORD,,
17
+ femtocrux-0.4.1.dist-info/LICENSE,sha256=eN9ZI1xHjUmFvN3TEeop5kBGXRUBfbsl55KBNBYYFqI,36
18
+ femtocrux-0.4.1.dist-info/METADATA,sha256=JjAykujLMZmW7rGL2Rx9HfmKkrk8UM8D6VH7S92ueZQ,3536
19
+ femtocrux-0.4.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
20
+ femtocrux-0.4.1.dist-info/top_level.txt,sha256=BkTttlioC3je__8577wxRieZqY3Abu7FOOdMnmYbcNI,10
21
+ femtocrux-0.4.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.2)
2
+ Generator: bdist_wheel (0.42.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5